import torch

class InputSampler:
    """Sample generator inputs: latent noise plus class labels.

    Configuration is read from ``cfg.Generator``: ``batch_size``,
    ``num_latent_feature``, ``noise_scale``, ``noise_clamp``, ``fixed_label``,
    ``fixed_noise`` and ``double_label``. The number of classes must be passed
    as the ``num_classes`` keyword argument.

    When ``fixed_label`` / ``fixed_noise`` is set, the labels / noise are drawn
    once at construction time (with ``cfg.Generator.batch_size`` rows) and the
    same tensor is returned on every call.
    """

    def __init__(self, cfg, *args, **kwargs):
        self.num_classes = kwargs['num_classes']
        # Label ids that sampled indices are mapped through (identity 0..K-1).
        self.classes = torch.arange(self.num_classes).long()
        self.batch_size = cfg.Generator.batch_size
        self.noise_size = cfg.Generator.num_latent_feature
        self.noise_scale = cfg.Generator.noise_scale
        self.noise_clamp = cfg.Generator.noise_clamp
        self.fixed_label = cfg.Generator.fixed_label
        self.fixed_noise = cfg.Generator.fixed_noise
        self.double_label = cfg.Generator.double_label

        # Pre-draw the fixed tensors once so every call reuses them.
        if self.fixed_label:
            if self.double_label:
                self.y = self.generate_double_y(self.batch_size)
            else:
                self.y = self.generate_y(self.batch_size)

        if self.fixed_noise:
            self.noise = self.generate_noise(self.batch_size)

    @staticmethod
    def _ceil_div(numerator, denominator):
        """Return ``ceil(numerator / denominator)`` for non-negative ints."""
        return (numerator + denominator - 1) // denominator

    def get_classes(self, label_string):
        """Parse a comma-separated string of integer labels.

        Returns:
            A tuple ``(count, labels)`` where ``labels`` is a LongTensor.
        """
        labels = [int(i) for i in label_string.split(",")]
        return len(labels), torch.tensor(labels).long()

    def generate_noise(self, batch_size):
        """Draw ``(batch_size, noise_size)`` Gaussian noise.

        The noise is scaled by ``noise_scale``. When ``noise_clamp > 0`` it is
        clamped to ``[-noise_clamp, noise_clamp]``.
        """
        noise = torch.randn(batch_size, self.noise_size) * self.noise_scale
        if self.noise_clamp > 0:
            noise = noise.clamp(-self.noise_clamp, self.noise_clamp)
        return noise

    def generate_double_y(self, batch_size):
        """Return a shuffled ``(batch_size, 2)`` LongTensor of label pairs.

        Every ordered pair ``(a, b)`` with ``a != b`` is tiled as evenly as
        possible to fill the batch, then the rows are shuffled.

        Raises:
            ValueError: if ``num_classes < 2``, since no pair of distinct
                labels exists.
        """
        num_classes = self.num_classes
        if num_classes < 2:
            raise ValueError(
                f"double_label requires num_classes >= 2, got {num_classes}"
            )
        # All ordered pairs (i, j) in row-major order, i.e. the same order as
        # meshgrid(..., indexing="ij") flattened.
        first = torch.arange(num_classes).repeat_interleave(num_classes)
        second = torch.arange(num_classes).repeat(num_classes)
        pairs = torch.stack([first, second], dim=1)

        # Remove pairs whose two labels coincide.
        pairs = pairs[pairs[:, 0] != pairs[:, 1]]

        # Tile enough copies to cover the batch. The round-up must be relative
        # to the number of pairs (K*(K-1)), not K; otherwise the batch can
        # come out short (e.g. K=3, batch_size=9 would give only 6 rows).
        repeats = self._ceil_div(batch_size, pairs.shape[0])
        pairs = pairs.repeat(repeats, 1)[:batch_size]

        # Shuffle rows and make sure the dtype is long for indexing.
        pairs = pairs[torch.randperm(pairs.shape[0])].long()
        return torch.stack(
            (self.classes[pairs[:, 0]], self.classes[pairs[:, 1]]), dim=1
        )

    def generate_y(self, batch_size):
        """Return a shuffled ``(batch_size,)`` LongTensor of class labels.

        Labels ``0..num_classes-1`` are tiled as evenly as possible to fill
        the batch before shuffling.
        """
        y = torch.arange(self.num_classes)
        repeats = self._ceil_div(batch_size, self.num_classes)
        y = y.repeat(repeats)[:batch_size]
        y = y[torch.randperm(y.shape[0])].long()
        return self.classes[y]

    def __call__(self, batch_size = None):
        """Return a ``(noise, y)`` pair of generator inputs.

        Args:
            batch_size: Rows to sample. ``None`` or values ``< 1`` fall back
                to ``cfg.Generator.batch_size``.

        NOTE(review): when ``fixed_noise`` / ``fixed_label`` is set, the cached
        tensor (sized for the configured batch size) is returned and
        ``batch_size`` is ignored for that output. Mixing a fixed and a
        non-fixed output with a custom ``batch_size`` therefore yields tensors
        whose first dimensions differ.
        """
        if batch_size is None or batch_size < 1:
            batch_size = self.batch_size

        if self.fixed_noise:
            noise = self.noise
        else:
            noise = self.generate_noise(batch_size)

        if self.fixed_label:
            y = self.y
        elif self.double_label:
            y = self.generate_double_y(batch_size)
        else:
            y = self.generate_y(batch_size)

        return noise, y
